package org.test4j.module.spring;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.test.context.TestContext;
import org.springframework.test.context.TestContextManager;
import org.test4j.Context;

import java.lang.ref.WeakReference;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static org.test4j.integration.junit5.JUnit5SpringHelper.getTestContextManager;
import static org.test4j.module.spring.SpringEnv.invokeSpringInitMethod;

/**
 * Spring-specific test initialization helpers.
 *
 * <p>Kept separate from {@code SpringEnv} so that, when Spring is not on the classpath, the
 * Spring types referenced here do not cause a {@link NoClassDefFoundError} in {@code SpringEnv}.
 *
 * @author wudarui
 */
public class SpringInit {
    /**
     * Spring application contexts registered per test class.
     * key: test class; value: weak reference to that test class's {@link ApplicationContext}.
     * NOTE(review): plain HashMap without synchronization — confirm contexts are never
     * registered/read concurrently (e.g. with JUnit parallel execution).
     */
    private static final Map<Class<?>, WeakReference<ApplicationContext>> springBeanFactories = new HashMap<>();
    /**
     * The {@link TestContextManager} stored by {@link #doSpringInitial(Object, TestContextManager)}
     * for the current thread (originally noted as being used for Spring transaction management).
     */
    private static final ThreadLocal<TestContextManager> springTestContextManager = new ThreadLocal<>();


    /**
     * Returns the Spring container registered for the current test class.
     *
     * @return the registered context, or empty if none was registered for
     * {@code Context.currTestClass()} or it has already been garbage-collected
     */
    public static Optional<ApplicationContext> getSpringContext() {
        WeakReference<ApplicationContext> reference = springBeanFactories.get(Context.currTestClass());
        // Read the referent exactly once: calling get() twice could see null the second time if the
        // context is collected in between, which would make Optional.of throw.
        return reference == null ? Optional.empty() : Optional.ofNullable(reference.get());
    }


    /**
     * Looks up the bean named {@code beanName} in the current test class's Spring container.
     *
     * @param beanName the bean name
     * @param <T>      the expected bean type (unchecked cast)
     * @return the bean, or {@code null} if there is no container or no such bean definition
     */
    @SuppressWarnings("unchecked")
    static <T> T getBeanByName(String beanName) {
        Object bean = getSpringContext().map(c -> {
            try {
                return c.getBean(beanName);
            } catch (NoSuchBeanDefinitionException e) {
                // absent bean is reported as null rather than an exception
                return null;
            }
        }).orElse(null);
        return (T) bean;
    }

    /**
     * Looks up the bean of type {@code beanType} in the current test class's Spring container.
     *
     * @param beanType the required bean type
     * @param <T>      the expected bean type (unchecked cast)
     * @return the bean, or {@code null} if there is no container or the lookup throws
     * {@link NoSuchBeanDefinitionException} (which includes its subclasses)
     */
    @SuppressWarnings("unchecked")
    static <T> T getBeanByType(Class<?> beanType) {
        Object bean = getSpringContext().map(c -> {
            try {
                return c.getBean(beanType);
            } catch (NoSuchBeanDefinitionException e) {
                // absent (or, via subclasses, ambiguous) bean is reported as null
                return null;
            }
        }).orElse(null);
        return (T) bean;
    }

    /**
     * Registers the Spring container for a test class. Only a weak reference is kept.
     *
     * @param testClass the test class the context belongs to
     * @param context   the Spring application context
     */
    public static void setSpringContext(Class<?> testClass, ApplicationContext context) {
        springBeanFactories.put(testClass, new WeakReference<>(context));
    }


    /**
     * Autowires and initializes {@code testedObject} using the current test class's Spring
     * container. Does nothing when {@code SpringEnv.isSpringEnv()} is false.
     *
     * @param testedObject the object to inject beans into
     * @throws IllegalStateException if Spring is enabled but no container is registered (or it
     *                               has been garbage-collected) for the current test class
     */
    static void injectSpringBeans(Object testedObject) {
        if (!SpringEnv.isSpringEnv()) {
            return;
        }
        ApplicationContext springContext = getSpringContext().orElseThrow(() -> new IllegalStateException(
            "no Spring ApplicationContext registered for test class: " + Context.currTestClass()));
        AutowireCapableBeanFactory beanFactory = springContext.getAutowireCapableBeanFactory();
        beanFactory.autowireBeanProperties(testedObject, AutowireCapableBeanFactory.AUTOWIRE_NO, false);
        beanFactory.initializeBean(testedObject, testedObject.getClass().getSimpleName());
    }

    /**
     * Runs around the start-up of the test's Spring container, resolving the
     * {@link TestContextManager} from the JUnit 5 extension context:
     * 1. invoke the {@code @BeforeSpringContext} methods
     * 2. prepare (inject) the test instance
     * 3. register the Spring container for the test class
     *
     * @param testInstance the test instance
     * @param context      the JUnit 5 extension context
     * @throws Exception propagated from {@link #doSpringInitial(Object, TestContextManager)}
     */
    public static void doSpringInitial(Object testInstance, ExtensionContext context) throws Exception {
        TestContextManager contextManager = getTestContextManager(context);
        SpringInit.doSpringInitial(testInstance, contextManager);
    }

    /**
     * Runs around the start-up of the test's Spring container:
     * 1. invoke the {@code @BeforeSpringContext} methods
     * 2. prepare (inject) the test instance
     * 3. register the Spring container for the test class
     *
     * @param testInstance   the test instance
     * @param contextManager the Spring test context manager; also stored for the current thread
     * @throws Exception propagated from the init methods or {@code prepareTestInstance}
     */
    public static void doSpringInitial(Object testInstance, TestContextManager contextManager) throws Exception {
        invokeSpringInitMethod(testInstance);
        springTestContextManager.set(contextManager);
        contextManager.prepareTestInstance(testInstance);
        ApplicationContext applicationContext = getApplicationContext(contextManager);
        SpringInit.setSpringContext(testInstance.getClass(), applicationContext);
    }

    /**
     * Obtains the {@link ApplicationContext} from a {@link TestContextManager}. In some versions
     * {@code getTestContext()} is not publicly accessible, so it is invoked reflectively.
     *
     * @param contextManager the Spring test context manager
     * @return the test's application context
     * @throws RuntimeException wrapping any reflection or context-loading failure
     */
    public static ApplicationContext getApplicationContext(TestContextManager contextManager) {
        try {
            // getDeclaredMethod, not getMethod: getMethod only finds public methods and would throw
            // NoSuchMethodException exactly when getTestContext() is non-public.
            Method method = TestContextManager.class.getDeclaredMethod("getTestContext");
            method.setAccessible(true);
            TestContext testContext = (TestContext) method.invoke(contextManager);
            return testContext.getApplicationContext();
        } catch (Exception e) {
            throw new RuntimeException("get Spring Application Context error: " + e.getMessage(), e);
        }
    }

    /**
     * Returns the {@link TestContextManager} stored for the current thread.
     *
     * @return the manager, or {@code null} if none has been set on this thread
     */
    public static TestContextManager getSpringTestContextManager() {
        return springTestContextManager.get();
    }
}
